{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chinese_gpt import TransformerEncoder as Encoder\n",
    "from chinese_gpt import TransformerDecoderLM as Decoder\n",
    "from pytorch_pretrained_bert import BertModel, BertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build both chinese_gpt networks with their default constructor arguments\n",
    "# (encoder first, then decoder).\n",
    "encoder, decoder = Encoder(), Decoder()\n",
    "\n",
    "# BERT model loaded from the \"bert-base-chinese\" pretrained weights.\n",
    "bert_model = BertModel.from_pretrained(\"bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite encoder entries with the BERT entries that share the same name,\n",
    "# then load the merged dict back into the encoder and save it.\n",
    "encoder_state_dict = encoder.state_dict()\n",
    "bert_state_dict = bert_model.state_dict()\n",
    "\n",
    "# Walk a snapshot of the names so entries can be reassigned during the loop.\n",
    "for param_name in list(encoder_state_dict):\n",
    "    if param_name not in bert_state_dict:\n",
    "        # No BERT entry under this name: leave the value untouched and report it.\n",
    "        print(param_name)\n",
    "        continue\n",
    "    encoder_state_dict[param_name] = bert_state_dict[param_name]\n",
    "\n",
    "encoder.load_state_dict(encoder_state_dict)\n",
    "torch.save(encoder.state_dict(), \"encoder.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the decoder's own state dict so every expected entry is present.\n",
    "decoder_state_dict = decoder.state_dict()\n",
    "\n",
    "# map_location=\"cpu\" lets a checkpoint saved from GPU tensors load on a machine\n",
    "# without CUDA; load_state_dict below copies the values into the decoder.\n",
    "# NOTE(review): torch.load unpickles the file, so only open trusted checkpoints.\n",
    "temp_state_dict = torch.load(\"model_state_epoch_62.th\", map_location=\"cpu\")\n",
    "\n",
    "for item in decoder_state_dict.keys():\n",
    "    if item in temp_state_dict:\n",
    "        decoder_state_dict[item] = temp_state_dict[item]\n",
    "    else:\n",
    "        # Entry missing from the checkpoint: keep the decoder's value and report it.\n",
    "        print(item)\n",
    "\n",
    "decoder.load_state_dict(decoder_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the decoder's current state dict to decoder.pth.\n",
    "torch.save(obj=decoder.state_dict(), f=\"decoder.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
